import os
import sys
sys.path.append('./')
import glob
import argparse

import json
import joblib
import torch
import numpy as np
from smplx import SMPL, SMPLX
from tqdm import tqdm

import config as _C
from utils import rotation as r
from utils.eval_utils import (batch_align_by_pelvis, 
                              batch_compute_similarity_transform_torch,)

# Names of the sequences evaluated by this script, built from
# (P number, Seq number) pairs as "PXX_SeqYY".
test_sequences = [
    f"P{p_id:02d}_Seq{seq_id:02d}"
    for p_id, seq_id in ((1, 1), (2, 2), (2, 3), (3, 3), (5, 1))
]


# Frame rate used to turn per-frame finite differences into per-second units.
fps = 30

# Multiplicative scale applied to each raw metric before it is reported.
metric = dict(
    mpjpe=1e3,  # mm
    pa_mpjpe=1e3,  # mm
    pve=1e3,  # mm
    pa_pve=1e3,  # mm
    accel=fps ** 2,  # m/s^2
    jitter=fps ** 3 / 1e2,  # 10^2m/s^3
)

# Display unit of every reported metric, matching the scale factors in `metric`.
_METRIC_UNITS = {
    "pve": "mm",
    "pa_pve": "mm",
    "mpjpe": "mm",
    "pa_mpjpe": "mm",
    "accel": "m/s^2",
    "jitter": "10^2 m/s^3",
    "jitter_gt": "10^2 m/s^3",
}


def _normalize_rows(rotation):
    """Scale every row of each matrix in ``rotation`` to unit length.

    ``keepdim=True`` keeps the reduced axis, so each row is divided by its own
    norm for any number of stacked matrices.
    """
    return rotation / rotation.norm(dim=-1, keepdim=True)


def _compute_video_metrics(gt_joints, pred_joints, gt_verts, pred_verts):
    """Compute all error metrics for one video.

    The four inputs are indexed by frame along their first dimension. They are
    moved to the CPU and pelvis-aligned before any error is measured.

    Returns:
        dict mapping each metric name (the keys of ``eval_results`` except
        ``vid_names``) to a 1-D tensor. ``accel`` has two entries fewer than
        the number of frames and ``jitter``/``jitter_gt`` three fewer, because
        they are second and third finite differences over frames.
    """
    # Use pelvis IDX 0 for alignment
    gt_joints, pred_joints, gt_verts, pred_verts = batch_align_by_pelvis(
        [gt_joints.cpu(), pred_joints.cpu(), gt_verts.cpu(), pred_verts.cpu()], [0])

    # MPJPE
    mpjpe = (gt_joints - pred_joints).norm(dim=-1).mean(1) * metric["mpjpe"]

    # PA-MPJPE
    aligned_joints = batch_compute_similarity_transform_torch(pred_joints, gt_joints)
    pa_mpjpe = (gt_joints - aligned_joints).norm(dim=-1).mean(1) * metric["pa_mpjpe"]

    # PVE
    pve = (gt_verts - pred_verts).norm(dim=-1).mean(1) * metric["pve"]

    # PA-PVE
    aligned_verts = batch_compute_similarity_transform_torch(pred_verts, gt_verts)
    pa_pve = (gt_verts - aligned_verts).norm(dim=-1).mean(1) * metric["pa_pve"]

    # Accel: second finite difference over frames, GT vs. prediction
    gt_accel = gt_joints[2:] - 2 * gt_joints[1:-1] + gt_joints[:-2]
    pred_accel = pred_joints[2:] - 2 * pred_joints[1:-1] + pred_joints[:-2]
    accel = (gt_accel - pred_accel).norm(dim=-1).mean(1) * metric["accel"]

    # Jitter: third finite difference over frames of the prediction
    pred_jitter = (pred_joints[3:] - 3 * pred_joints[2:-1]
                   + 3 * pred_joints[1:-2] - pred_joints[:-3])
    jitter = pred_jitter.norm(dim=-1).mean(1) * metric["jitter"]

    # Jitter (GT)
    gt_jitter = (gt_joints[3:] - 3 * gt_joints[2:-1]
                 + 3 * gt_joints[1:-2] - gt_joints[:-3])
    jitter_gt = gt_jitter.norm(dim=-1).mean(1) * metric["jitter"]

    return {
        "pve": pve,
        "pa_pve": pa_pve,
        "mpjpe": mpjpe,
        "pa_mpjpe": pa_mpjpe,
        "accel": accel,
        "jitter": jitter,
        "jitter_gt": jitter_gt,
    }


def _concat_metric(chunks):
    """Join the per-video chunks of one metric into a single 1-D numpy array.

    An empty list yields an empty array instead of the ``ValueError`` that
    ``np.concatenate`` raises on empty input.
    """
    if not chunks:
        return np.zeros(0, dtype=np.float32)
    return np.concatenate([np.asarray(chunk) for chunk in chunks], axis=0)


def _format_summary(model_name, eval_results):
    """Build the one-line report of the mean of every metric.

    ``vid_names`` is skipped because it holds names, not numbers. A metric
    without any value is reported as ``nan``.
    """
    fields = [f"Model {model_name}"]
    for name, values in eval_results.items():
        if name == "vid_names":
            continue
        values = np.asarray(values)
        mean = values.mean() if values.size else float("nan")
        unit = _METRIC_UNITS.get(name, "")
        fields.append(f"{name}: {mean:.1f} {unit}".rstrip())
    return "  | ".join(fields)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', default='camerahmr')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    smpl_model = SMPL(_C.SMPL_MODEL_DIR, gender="neutral", num_betas=10).eval().to(device)

    # Matrix applied to SMPL-X vertices to obtain SMPL vertices.
    smplx2smpl = torch.from_numpy(
        joblib.load(_C.SMPLX2SMPL_PTH)["matrix"]
    ).float().to(device)

    eval_results = dict(
        pve=[],
        pa_pve=[],
        mpjpe=[],
        pa_mpjpe=[],
        accel=[],
        jitter=[],
        jitter_gt=[],
        vid_names=[]
    )

    with torch.no_grad():
        for sequence in tqdm(test_sequences):
            annotation_pth = os.path.join(_C.DATA_BASE_DIR, sequence, f"{sequence}_annot.pkl")
            annotation = joblib.load(annotation_pth)

            gender = annotation["gender"]
            extrinsics = annotation["extrinsics"]

            gt_params = {
                k: torch.from_numpy(v).float().to(device)
                for k, v in annotation.items()
                if k in ["body_pose", "global_orient", "betas"]
            }
            _C.SEQUENCE_NAME = sequence

            for camera_name in _C.EXO_CAMERA_NAMES:
                camera_i = int(camera_name[-1]) - 1
                if camera_i not in annotation["camera_idxs"]:
                    continue

                # Placeholder: replace with the path of the prediction file
                # for this sequence and camera.
                prediction_pth = "path-to-your-prediction-path"
                pred = joblib.load(prediction_pth)

                # Convert world-coordinate GT to camera coordinate
                # NOTE(review): `extrinsics` and `bboxes` are used as stored,
                # without indexing by `camera_i` -- confirm whether they are
                # per-camera arrays that need to be indexed here.
                R = torch.from_numpy(extrinsics.reshape(-1, 3, 4)[:, :3, :3]).float().to(device)
                R = _normalize_rows(R)
                gt_rotmat_world = r.axis_angle_to_matrix(gt_params["global_orient"])
                gt_rotmat_cam = R @ gt_rotmat_world
                # Per-camera copy: `gt_params` must keep the world-frame
                # orientation so the next camera does not rotate it again.
                gt_params_cam = dict(
                    gt_params,
                    global_orient=r.matrix_to_axis_angle(gt_rotmat_cam),
                )

                bboxes = annotation["bboxes"]
                # Frames whose last bbox entry is -1 are excluded.
                valid_idxs = np.where(bboxes[:, -1] != -1)[0]

                # Construct ground truth body
                smplx_model = SMPLX(_C.SMPLX_MODEL_DIR, gender=gender,
                                num_betas=11,
                                batch_size=valid_idxs.shape[0]
                ).eval().to(device)

                gt_body = smplx_model(
                    **{k: v[valid_idxs] for k, v in gt_params_cam.items()}
                )

                gt_verts = torch.matmul(smplx2smpl, gt_body.vertices)
                gt_joints = torch.matmul(smpl_model.J_regressor, gt_verts)[:, :24]

                # Build prediction
                pred_params = dict(
                    global_orient=torch.from_numpy(pred["global_orient"]).float().to(device),
                    body_pose=torch.from_numpy(pred["body_pose"]).float().to(device),
                    betas=torch.from_numpy(pred["betas"]).float().to(device)
                )

                # Or any method that uses SMPL_X model
                if args.model.lower() == "bedlam-cliff":
                    num_betas = pred_params["betas"].size(-1)
                    body_model = SMPLX(_C.SMPLX_MODEL_DIR, gender="neutral",
                                num_betas=num_betas,
                                batch_size=valid_idxs.shape[0]
                    ).eval().to(device)

                    pred_body = body_model(**pred_params)

                    pred_verts = torch.matmul(smplx2smpl, pred_body.vertices)
                    pred_joints = torch.matmul(smpl_model.J_regressor, pred_verts)[:, :24]

                else:
                    # Already inside the enclosing torch.no_grad() context.
                    pred_body = smpl_model(**pred_params)

                    pred_verts = pred_body.vertices
                    pred_joints = pred_body.joints[:, :24]

                # Score this (sequence, camera) video. This must stay inside
                # the camera loop so that every processed camera is evaluated.
                video_metrics = _compute_video_metrics(
                    gt_joints, pred_joints, gt_verts, pred_verts)
                for name, values in video_metrics.items():
                    eval_results[name].append(values)
                eval_results["vid_names"].append(f"{sequence}_{camera_name}")

    # Collapse the per-video chunks; `vid_names` stays a plain list of strings.
    for name in eval_results:
        if name != "vid_names":
            eval_results[name] = _concat_metric(eval_results[name])

    print(_format_summary(args.model, eval_results))
    print()